#!/bin/bash

# Load the recipe environment; the relative path means this script must be
# run from the recipe root.
. ./path.sh || exit 1;

# ---------------------------------------------------------------------------
# Defaults. Every variable declared before tools/parse_options.sh can be
# overridden on the command line as "--name value".
# ---------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
stage=5 # start from 0 if you need to start from data_list preparation
stop_stage=6

dir=/home/work_nfs7/xlgeng/bsmu_template/exp/salmonn_v8_lr5e_5
train_config=conf/train_salmonn_v8.yaml
# The training checkpoint is set in the config file, not here.
# checkpoint=$dir/1.pt
data_type=shard # raw or shard
num_workers=8  # number of data-loading worker processes
prefetch=200

average_checkpoint=false
# Empty means "$dir/0.pt". It is resolved after option parsing so that the
# default follows a --dir override.
decode_checkpoint=
decode_checkpoint_name="0pt"
average_num=10
# decode_modes="attention_rescoring ctc_greedy_search ctc_prefix_beam_search attention"
decode_modes="salmonn_decode"
# GPU index passed to recognize.py --gpu. It is declared before option
# parsing so that --gpu_id is accepted.
gpu_id=0

HOST_NODE_ADDR="localhost:0"
num_nodes=1

# cmvn settings have also moved into the config file.
#cmvn=false
#do_delta=false

deepspeed_config=conf/ds_stage2.json
deepspeed_save_states="model_only"

. tools/parse_options.sh || exit 1;

decode_checkpoint=${decode_checkpoint:-$dir/0.pt}

echo "开始打印主要变量，这些变量有命令行传入"
echo "dir=$dir"
echo "train_config=$train_config"
echo "decode_checkpoint=$decode_checkpoint"
echo "decode_checkpoint_name=$decode_checkpoint_name"

set -euo pipefail

# Created only after option parsing so that a --dir override is honoured and
# the hard-coded default path is not created as a side effect.
mkdir -p "$dir"

# Test sets to decode/score. The stages below read each one from
# data_list/test/<name>/.
test_sets=(
  "aishell1" "aishell2"
  "SPEECHIO_ASR_ZH00000" "SPEECHIO_ASR_ZH00001" "SPEECHIO_ASR_ZH00002"
  "SPEECHIO_ASR_ZH00003" "SPEECHIO_ASR_ZH00004" "SPEECHIO_ASR_ZH00005"
  "test_meeting" "test_net"
)

# The LLM's own tokenizer is used, so no dict / bpemodel has to be passed.
dict=""
bpemodel=""
# Stage 5: decode every test set with recognize.py, then score it with
# compute-wer.py. Test sets are processed sequentially on GPU $gpu_id.
if [ "${stage}" -le 5 ] && [ "${stop_stage}" -ge 5 ]; then
  ctc_weight=0.5
  for test_set in "${test_sets[@]}"; do
    echo "test this dataset: $test_set"
    test_dir=$dir/test_${decode_checkpoint_name}/${test_set}
    # To skip test sets that already have a WER file, uncomment:
    # wer_path=$test_dir/wer
    # if [ -e "$wer_path" ]; then
    #   echo "$wer_path already exists, skipping inference for this dataset"
    #   continue
    # fi
    mkdir -p "$test_dir"
    # $decode_modes is deliberately unquoted: it may hold several
    # space-separated modes, each of which must become its own argument.
    # shellcheck disable=SC2086
    python wenet/bin/recognize.py --gpu "$gpu_id" \
      --modes $decode_modes \
      --config "$dir/train.yaml" \
      --data_type raw \
      --test_data "data_list/test/$test_set/data.list" \
      --checkpoint "$decode_checkpoint" \
      --beam_size 10 \
      --batch_size 1 \
      --penalty 0.0 \
      --result_dir "$test_dir" \
      --ctc_weight "$ctc_weight"

    hyp=$test_dir/text_hyp
    # NOTE(review): stage 6 scores "$test_dir/text" whereas this stage scores
    # "$test_dir/text_hyp"; confirm which file recognize.py actually writes.
    if [ ! -f "$hyp" ]; then
      echo "$0: hypothesis file $hyp not found for test set $test_set" >&2
      exit 1
    fi
    python tools/compute-wer.py --char=1 --v=1 \
      "data_list/test/$test_set/text" "$hyp" > "$test_dir/wer"
    echo "$test_set has been decoded!"
  done
fi
# Stage 6: (re)compute WER only, from hypotheses already present under
# $dir/test_${decode_checkpoint_name}/<test_set>/. No decoding happens here.
if [ "${stage}" -le 6 ] && [ "${stop_stage}" -ge 6 ]; then
  for test_set in "${test_sets[@]}"; do
    echo "compute wer this dataset: $test_set"
    test_dir=$dir/test_${decode_checkpoint_name}/${test_set}
    hyp=$test_dir/text
    # NOTE(review): stage 5 scores "$test_dir/text_hyp" whereas this stage
    # scores "$test_dir/text"; confirm which file recognize.py actually writes.
    if [ ! -f "$hyp" ]; then
      echo "$0: hypothesis file $hyp not found for test set $test_set" >&2
      exit 1
    fi
    python tools/compute-wer.py --char=1 --v=1 \
      "data_list/test/$test_set/text" "$hyp" > "$test_dir/wer"
    echo "$test_set WER has been computed!"
  done
fi
